
package com.jstarcraft.ai.jsat.classifiers.calibration;

import java.util.Collections;
import java.util.List;

import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.DataPointPair;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.parameters.Parameter.ParameterHolder;

/**
 * This abstract class provides the framework for an algorithm to perform
 * probability calibration based on the outputs of a base learning algorithm for
 * binary classification problems. <br>
 * <br>
 * Calibration can be performed directly on the scores produced for the training
 * data, though this may cause over-fitting. For this reason, the
 * {@link CalibrationMode} may be set to an alternative method. <br>
 * <br>
 * The parameters include the calibration parameters, and any parameters that
 * would be returned by the base model.
 * 
 * @author Edward Raff
 */
public abstract class BinaryCalibration implements Classifier, Parameterized {

    private static final long serialVersionUID = 2356311701854978890L;
    /**
     * The base classifier to train and calibrate the outputs of
     */
    @ParameterHolder
    protected BinaryScoreClassifier base;
    /**
     * The number of CV folds
     */
    protected int folds = 3;
    /**
     * The proportion of the data set to hold out for calibration
     */
    protected double holdOut = 0.3;
    /**
     * The calibration mode to use
     */
    protected CalibrationMode mode;

    /**
     * Creates a new Binary Calibration object
     * 
     * @param base the base learning algorithm
     * @param mode the calibration mode to use
     */
    public BinaryCalibration(BinaryScoreClassifier base, CalibrationMode mode) {
        this.base = base;
        // Assigned directly instead of through setCalibrationMode so that
        // construction never dispatches to a subclass override while the
        // subclass is still uninitialized.
        this.mode = mode;
    }

    /**
     * Controls how the scores are obtained for producing a "training set" to
     * calibrate the output of the underlying model.
     */
    public enum CalibrationMode {
        /**
         * The naive method trains the classifier on the whole data set, and then
         * produces the scores for each training point. This may cause over-fitting.
         */
        NAIVE,
        /**
         * The model will be trained by cross validation, using the number of folds
         * given by {@link BinaryCalibration#setCalibrationFolds(int)}. The default is
         * 3 folds, where the classifier will be trained on the folds not in the set,
         * and then produce scores for the unobserved test points in the held out
         * fold.
         */
        CV,
        /**
         * The model will have a random
         * {@link BinaryCalibration#setCalibrationHoldOut(double) fraction} of the
         * data set held out, and trained on the rest of the data. The scores will
         * then be produced for the held out data and used for calibration.
         */
        HOLD_OUT,
    }

    /**
     * Trains the base classifier and then calibrates its scores. <br>
     * <br>
     * Score / label pairs are gathered according to the current
     * {@link CalibrationMode}; a data point counts as positive when its category
     * index is 1. Whatever the mode, the base classifier ends up trained on the
     * complete data set before {@link #calibrate(boolean[], double[], int)} is
     * invoked. A {@code null} mode is handled the same way as
     * {@link CalibrationMode#NAIVE}.
     * 
     * @param dataSet  the data set to train and calibrate on
     * @param parallel passed through unchanged to the base classifier's training
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        // Sized for the largest case; the helpers report how much was filled in
        double[] scores = new double[dataSet.size()];// decision values of the base model
        boolean[] label = new boolean[scores.length];// is the example in the positive class?

        final int len;
        if (mode == CalibrationMode.CV) {
            len = scoreByCrossValidation(dataSet, parallel, scores, label);
        } else if (mode == CalibrationMode.HOLD_OUT) {
            len = scoreByHoldOut(dataSet, parallel, scores, label);
        } else {
            len = scoreNaively(dataSet, parallel, scores, label);
        }

        calibrate(label, scores, len);
    }

    /**
     * Fills the arrays with out-of-fold scores: for every fold the base model is
     * trained on the remaining folds and scores the fold that was left out. The
     * base model is retrained on the whole data set afterwards.
     * 
     * @return the number of entries written, which is the sum of the fold sizes
     */
    private int scoreByCrossValidation(ClassificationDataSet dataSet, boolean parallel, double[] scores, boolean[] label) {
        List<ClassificationDataSet> foldList = dataSet.cvSet(folds);
        int pos = 0;
        for (int i = 0; i < foldList.size(); i++) {
            ClassificationDataSet test = foldList.get(i);
            ClassificationDataSet train = ClassificationDataSet.comineAllBut(foldList, i);
            base.train(train, parallel);

            for (int j = 0; j < test.size(); j++) {
                scores[pos] = base.getScore(test.getDataPoint(j));
                label[pos] = test.getDataPointCategory(j) == 1;
                pos++;
            }
        }

        // The model that is finally kept is fit on every available point
        base.train(dataSet, parallel);
        // Report what was actually filled so unfilled slots are never calibrated on
        return pos;
    }

    /**
     * Fills the arrays with scores for a randomly chosen held-out portion of the
     * data, produced by a base model trained on the remainder. The base model is
     * retrained on the whole data set afterwards.
     * 
     * @return the number of entries written, which is the size of the held-out set
     */
    private int scoreByHoldOut(ClassificationDataSet dataSet, boolean parallel, double[] scores, boolean[] label) {
        List<DataPointPair<Integer>> wholeSet = dataSet.getAsDPPList();
        Collections.shuffle(wholeSet);

        // The first (1 - holdOut) share is used for training, the tail for scoring
        int splitMark = (int) (wholeSet.size() * (1 - holdOut));
        ClassificationDataSet train = new ClassificationDataSet(wholeSet.subList(0, splitMark), dataSet.getPredicting());
        ClassificationDataSet test = new ClassificationDataSet(wholeSet.subList(splitMark, wholeSet.size()), dataSet.getPredicting());

        base.train(train, parallel);
        for (int i = 0; i < test.size(); i++) {
            scores[i] = base.getScore(test.getDataPoint(i));
            label[i] = test.getDataPointCategory(i) == 1;
        }

        base.train(dataSet, parallel);
        return test.size();
    }

    /**
     * Trains the base model on the whole data set and fills the arrays with the
     * scores it gives to those same training points.
     * 
     * @return the number of entries written, which is the full array length
     */
    private int scoreNaively(ClassificationDataSet dataSet, boolean parallel, double[] scores, boolean[] label) {
        base.train(dataSet, parallel);

        for (int i = 0; i < scores.length; i++) {
            DataPoint dp = dataSet.getDataPoint(i);
            scores[i] = base.getScore(dp);
            label[i] = dataSet.getDataPointCategory(i) == 1;
        }
        return scores.length;
    }

    /**
     * This method performs the model calibration on the outputs versus the class
     * labels.
     * 
     * @param label  the set of labels, where {@code true} indicates the positive
     *               class label, and {@code false} indicates the negative class
     *               label.
     * @param scores the score associated with each label from the learning
     *               algorithm.
     * @param len    the number of values (from zero) of the label and scores array
     *               to use. This value may be less than the actual array size
     */
    protected abstract void calibrate(boolean[] label, double[] scores, final int len);

    /**
     * If the calibration mode is set to {@link CalibrationMode#CV}, this controls
     * how many folds of cross validation will be used. The default is 3. <br>
     * <br>
     * NOTE(review): a value of 1 passes this check, but in CV mode each fold's
     * training set is built from the other folds, so a single fold presumably
     * leaves nothing to train on - confirm before relying on it.
     * 
     * @param folds the number of cross validation folds to perform
     * @throws IllegalArgumentException if {@code folds} is less than 1
     */
    public void setCalibrationFolds(int folds) {
        if (folds < 1) {
            throw new IllegalArgumentException("Folds must be a positive value, not " + folds);
        }
        this.folds = folds;
    }

    /**
     * Returns the number of cross validation folds to use
     * 
     * @return the number of cross validation folds to use
     */
    public int getCalibrationFolds() {
        return folds;
    }

    /**
     * If the calibration mode is set to {@link CalibrationMode#HOLD_OUT}, this
     * controls what portion of the data set is randomly selected to be the hold
     * out set. The default is 0.3.
     * 
     * @param holdOut the portion in (0, 1) to hold out
     * @throws IllegalArgumentException if {@code holdOut} is NaN or not strictly
     *                                  between 0 and 1
     */
    public void setCalibrationHoldOut(double holdOut) {
        if (Double.isNaN(holdOut) || holdOut <= 0 || holdOut >= 1) {
            throw new IllegalArgumentException("HoldOut must be in (0, 1), not " + holdOut);
        }
        this.holdOut = holdOut;
    }

    /**
     * Returns the portion of the data set that will be held out for calibration
     * 
     * @return the portion of the data set that will be held out for calibration
     */
    public double getCalibrationHoldOut() {
        return holdOut;
    }

    /**
     * Sets which calibration mode will be used during training
     * 
     * @param mode the calibration mode to use during training.
     */
    public void setCalibrationMode(CalibrationMode mode) {
        this.mode = mode;
    }

    /**
     * Returns the calibration mode used during training
     * 
     * @return the calibration mode used during training
     */
    public CalibrationMode getCalibrationMode() {
        return mode;
    }

    @Override
    public abstract BinaryCalibration clone();

}
